package dev.langchain4j.classification;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.CosineSimilarity;
import dev.langchain4j.store.embedding.RelevanceScore;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static dev.langchain4j.internal.ValidationUtils.ensureBetween;
import static dev.langchain4j.internal.ValidationUtils.ensureGreaterThanZero;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static java.util.Comparator.comparingDouble;
import static java.util.stream.Collectors.toList;

/**
 * A {@link TextClassifier} that uses an {@link EmbeddingModel} and predefined examples to perform classification.
 * Classification is done by comparing the embedding of the text being classified with the embeddings of predefined examples.
 * The classification quality improves with a greater number of examples for each label.
 * Examples can be easily generated with the help of an LLM.
 * <p>
 * All examples are embedded once, eagerly, when the classifier is constructed; each call to
 * {@link #classifyWithScores(String)} then embeds only the text being classified.
 * <p>
 * Example:
 * <pre>{@code
 * enum Sentiment {
 *     POSITIVE, NEUTRAL, NEGATIVE
 * }
 *
 * Map<Sentiment, List<String>> examples = Map.of(
 *     POSITIVE, List.of("This is great!", "Wow, awesome!"),
 *     NEUTRAL,  List.of("Well, it's fine", "It's ok"),
 *     NEGATIVE, List.of("It is pretty bad", "Worst experience ever!")
 * );
 *
 * EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel();
 *
 * TextClassifier<Sentiment> classifier = new EmbeddingModelTextClassifier<>(embeddingModel, examples);
 *
 * List<Sentiment> sentiments = classifier.classify("Awesome!");
 * System.out.println(sentiments); // [POSITIVE]
 * }</pre>
 *
 * @param <L> The type of the label (e.g., String, Enum, etc.)
 */
public class EmbeddingModelTextClassifier<L> implements TextClassifier<L> {

    private final EmbeddingModel embeddingModel;
    private final Map<L, List<Embedding>> exampleEmbeddingsByLabel;
    private final int maxResults;
    private final double minScore;
    private final double meanToMaxScoreRatio;

    /**
     * Creates a classifier with the default values for {@code maxResults} (1), {@code minScore} (0)
     * and {@code meanToMaxScoreRatio} (0.5).
     *
     * @param embeddingModel  The embedding model used for embedding both the examples and the text to be classified.
     * @param examplesByLabel A map containing examples of texts for each label.
     *                        The more examples, the better. Examples can be easily generated by the LLM.
     */
    public EmbeddingModelTextClassifier(EmbeddingModel embeddingModel,
                                        Map<L, ? extends Collection<String>> examplesByLabel) {
        this(embeddingModel, examplesByLabel, 1, 0, 0.5);
    }

    /**
     * Creates a classifier and eagerly embeds all provided examples.
     *
     * @param embeddingModel      The embedding model used for embedding both the examples and the text to be classified.
     * @param examplesByLabel     A map containing examples of texts for each label.
     *                            The more examples, the better. Examples can be easily generated by the LLM.
     *                            Labels whose example collection is empty are ignored, since no score can be
     *                            computed for them.
     * @param maxResults          The maximum number of labels to return for each classification. Must be greater than 0.
     * @param minScore            The minimum similarity score required for classification, in the range [0..1].
     *                            Labels scoring lower than this value will be discarded.
     * @param meanToMaxScoreRatio A ratio, in the range [0..1], between the mean and max scores used for calculating
     *                            the final score.
     *                            During classification, the embeddings of examples for each label are compared to
     *                            the embedding of the text being classified.
     *                            This results in two metrics: the mean and max scores.
     *                            The mean score is the average similarity score for all examples associated with a given label.
     *                            The max score is the highest similarity score, corresponding to the example most
     *                            similar to the text being classified.
     *                            A value of 0 means that only the mean score will be used for ranking labels.
     *                            A value of 0.5 means that both scores will contribute equally to the final score.
     *                            A value of 1 means that only the max score will be used for ranking labels.
     */
    public EmbeddingModelTextClassifier(EmbeddingModel embeddingModel,
                                        Map<L, ? extends Collection<String>> examplesByLabel,
                                        int maxResults,
                                        double minScore,
                                        double meanToMaxScoreRatio) {
        this.embeddingModel = ensureNotNull(embeddingModel, "embeddingModel");
        ensureNotNull(examplesByLabel, "examplesByLabel");

        // Validate the cheap scalar arguments first, so a misconfiguration fails before
        // the (potentially expensive) embedding of every example is performed.
        this.maxResults = ensureGreaterThanZero(maxResults, "maxResults");
        this.minScore = ensureBetween(minScore, 0.0, 1.0, "minScore");
        this.meanToMaxScoreRatio = ensureBetween(meanToMaxScoreRatio, 0.0, 1.0, "meanToMaxScoreRatio");

        Map<L, List<Embedding>> embeddingsByLabel = new HashMap<>();
        examplesByLabel.forEach((label, examples) -> {
            if (examples.isEmpty()) {
                // A label without examples can never be scored (its mean would be 0/0 = NaN),
                // so do not ask the model to embed an empty batch for it.
                return;
            }
            List<TextSegment> segments = examples.stream()
                    .map(TextSegment::from)
                    .collect(toList());
            embeddingsByLabel.put(label, embeddingModel.embedAll(segments).content());
        });
        this.exampleEmbeddingsByLabel = embeddingsByLabel;
    }

    /**
     * Scores every label against {@code text} and returns at most {@code maxResults} labels whose
     * aggregated score is at least {@code minScore}, ordered from highest to lowest score.
     *
     * @param text The text to classify.
     * @return The scored labels, best match first.
     */
    @Override
    public ClassificationResult<L> classifyWithScores(String text) {

        Embedding textEmbedding = embeddingModel.embed(text).content();

        List<ScoredLabel<L>> scoredLabels = new ArrayList<>();
        exampleEmbeddingsByLabel.forEach((label, exampleEmbeddings) -> {
            if (exampleEmbeddings.isEmpty()) {
                // Defensive: a mean over zero examples is undefined.
                return;
            }

            double scoreSum = 0;
            // Relevance scores are non-negative, so 0 is a safe starting point for the maximum.
            double maxScore = 0;
            for (Embedding exampleEmbedding : exampleEmbeddings) {
                double cosineSimilarity = CosineSimilarity.between(textEmbedding, exampleEmbedding);
                double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
                scoreSum += score;
                maxScore = Math.max(score, maxScore);
            }
            double meanScore = scoreSum / exampleEmbeddings.size();

            double aggregateScore = aggregatedScore(meanScore, maxScore);
            if (aggregateScore >= minScore) {
                scoredLabels.add(new ScoredLabel<>(label, aggregateScore));
            }
        });

        return new ClassificationResult<>(
                scoredLabels.stream()
                        // descending order: highest score first
                        .sorted(comparingDouble((ScoredLabel<L> scoredLabel) -> scoredLabel.score()).reversed())
                        .limit(maxResults)
                        .collect(toList())
        );
    }

    /**
     * Blends the mean and max scores as documented on the constructor:
     * {@code meanToMaxScoreRatio == 0} uses only the mean score, {@code 1} uses only the max score,
     * and {@code 0.5} weights both equally.
     */
    private double aggregatedScore(double meanScore, double maxScore) {
        return ((1 - meanToMaxScoreRatio) * meanScore) + (meanToMaxScoreRatio * maxScore);
    }
}
